use std::collections::BTreeMap;
use std::io::prelude::*;
use std::io::SeekFrom;
use std::str::FromStr;
use windows_bindgen::*;

fn main() {
    //! Regenerates the MSVC import library for the Visual Studio target platform.
    //!
    //! Every library returned by `libraries()` is turned into its own import library under
    //! `crates/targets/baseline` (via `build_library`). Those are merged by `lib` into a single
    //! `windows.lib`, normalized with `make_reproducible`, and moved to
    //! `crates/targets/{platform}/lib/windows.{CARGO_PKG_VERSION}.lib`. The generated `.c`
    //! stubs and a readme stay behind in the baseline directory.
    //!
    //! Must be run from a Visual Studio developer command prompt, which supplies the
    //! `Platform` environment variable and puts `cl`/`lib` on `PATH`.

    /// Runs `cmd` to completion, panicking with the tool's output if it reports failure.
    fn run(cmd: &mut std::process::Command) {
        let result = cmd.output().unwrap();
        assert!(
            result.status.success(),
            "{cmd:?} failed ({})\n{}{}",
            result.status,
            String::from_utf8_lossy(&result.stdout),
            String::from_utf8_lossy(&result.stderr),
        );
    }

    // Read at run time: the `cl`/`lib` tools invoked below come from the *current*
    // environment, so the platform name must come from it too (a compile-time
    // `option_env!` could disagree with the tools actually on PATH).
    let platform = match std::env::var("Platform").as_deref() {
        Ok("x86") => "i686_msvc",
        Ok("x64") => "x86_64_msvc",
        Ok("arm64") => "aarch64_msvc",
        Ok(_) => {
            println!("Unknown platform");
            return;
        }
        Err(_) => {
            println!("Please run this tool from a Visual Studio command prompt");
            return;
        }
    };

    let libraries = libraries();
    let output = std::path::PathBuf::from("crates/targets/baseline");
    // Start from a clean baseline directory; it may not exist yet.
    _ = std::fs::remove_dir_all(&output);
    std::fs::create_dir_all(&output).unwrap();

    std::fs::write(
        output.join("readme.md"),
        r#"
This "baseline" is generated by `tool_msvc` and is only used to provide a visual diff of what
imports may have changed from one release to the next. The `lib.yml` workflow ensures that no
changes can sneak in undetected.
"#,
    )
    .unwrap();

    for (library, functions) in &libraries {
        build_library(&output, library, functions);
    }

    // Merge the per-library import libraries into a single archive.
    let mut cmd = std::process::Command::new("lib");
    cmd.current_dir(&output);
    cmd.arg("/nologo");
    cmd.arg("/Brepro");
    cmd.arg("/out:windows.lib");
    cmd.args(libraries.keys().map(|library| format!("{library}.lib")));
    run(&mut cmd);

    for library in libraries.keys() {
        std::fs::remove_file(output.join(format!("{library}.lib"))).unwrap();
    }

    make_reproducible(&output.join("windows.lib"));

    let lib_dir = std::path::PathBuf::from(format!("crates/targets/{platform}/lib"));
    // Clear out libraries from previous versions. A missing directory is fine; any other
    // failure would leave stale files next to the new library, so it is still fatal.
    if let Err(error) = std::fs::remove_dir_all(&lib_dir) {
        assert_eq!(
            error.kind(),
            std::io::ErrorKind::NotFound,
            "{lib_dir:?}: {error}"
        );
    }
    std::fs::create_dir_all(&lib_dir).unwrap();
    std::fs::rename(
        output.join("windows.lib"),
        lib_dir.join(format!(
            "windows.{}.lib",
            std::env!("CARGO_PKG_VERSION")
        )),
    )
    .unwrap();
}

fn build_library(
    output: &std::path::Path,
    library: &str,
    functions: &BTreeMap<String, CallingConvention>,
) {
    // Note that we don't use set_extension as it confuses PathBuf when the library name includes a period.

    let mut path = std::path::PathBuf::from(output);
    path.push(format!("{library}.c"));
    let mut c = std::fs::File::create(&path).unwrap();

    path.pop();
    path.push(format!("{library}.def"));
    let mut def = std::fs::File::create(&path).unwrap();

    def.write_all(
        format!(
            r#"
LIBRARY {library}
EXPORTS
"#
        )
        .as_bytes(),
    )
    .unwrap();

    for (function, calling_convention) in functions {
        let buffer = match calling_convention {
            CallingConvention::Stdcall(size) => {
                let mut buffer = format!("void __stdcall {function}(");

                for param in 0..(*size / 4) {
                    use std::fmt::Write;
                    write!(&mut buffer, "int p{param}, ").unwrap();
                }

                if buffer.ends_with(' ') {
                    buffer.truncate(buffer.len() - 2);
                }

                buffer.push_str(") {}\n");
                buffer
            }
            CallingConvention::Cdecl => {
                format!("void __cdecl {function}() {{}}\n")
            }
        };

        c.write_all(buffer.as_bytes()).unwrap();
        def.write_all(format!("{function}\n").as_bytes()).unwrap();
    }

    drop(c);
    drop(def);

    let mut cmd = std::process::Command::new("cl");
    cmd.current_dir(output);
    cmd.arg("/nologo");
    cmd.arg("/c");
    cmd.arg(format!("{library}.c"));
    cmd.output().unwrap();

    let mut cmd = std::process::Command::new("lib");
    cmd.current_dir(output);
    cmd.arg("/nologo");
    cmd.arg(format!("/out:{library}.lib"));
    cmd.arg(format!("/def:{library}.def"));
    cmd.arg(format!("{library}.obj"));
    cmd.output().unwrap();

    path.pop();
    path.push(format!("{library}.c"));
    // Don't remove the .c file as it represents the baseline.

    path.pop();
    path.push(format!("{library}.def"));
    std::fs::remove_file(&path).unwrap();

    path.pop();
    path.push(format!("{library}.exp"));
    std::fs::remove_file(&path).unwrap_or_else(|_| panic!("{:?}", path));

    path.pop();
    path.push(format!("{library}.obj"));
    std::fs::remove_file(&path).unwrap();
}

/// Rewrites, in place, the fields of an MSVC-generated archive (`.lib`) that would otherwise
/// differ between runs or between linker versions.
///
/// The linker emits the following non-reproducible elements during lib generation:
///   * archive member header timestamps (already set to -1 via /Brepro; left untouched here)
///   * import header time/date stamps
///   * COFF file header time/date stamps
///   * compiler version information: the value of a leading `@comp.id` symbol and the
///     version fields of the `0x1013` compile record in `.debug$S` sections
///
/// All of these are overwritten with zeros so the output is reproducible across linker
/// versions and time.
///
/// # Panics
///
/// Panics on any I/O error. It also panics on these structural mismatches:
///   * a missing `!<arch>\n` magic
///   * a member header without the "`\n" terminator
///   * an import header whose second signature is not `0xFFFF`
///   * a `.debug$S` section whose CodeView signature is not 2
fn make_reproducible(lib: &std::path::Path) {
    let mut archive = std::fs::File::options()
        .read(true)
        .write(true)
        .open(lib)
        .unwrap();
    let len = archive.metadata().unwrap().len();

    // `read_array_at` leaves the cursor just past the magic, i.e. at the first member header.
    let magic: [u8; 8] = read_array_at(&mut archive, 0);
    assert_eq!(&magic, b"!<arch>\n");

    while archive.stream_position().unwrap() < len {
        let (name, size) = read_member_header(&mut archive);
        let start = archive.stream_position().unwrap();

        if !is_special_member(&name) {
            patch_member(&mut archive, start);
        }

        // Member data is padded to an even offset; `size` excludes that padding byte.
        archive
            .seek(SeekFrom::Start(start + ((size + 1) & !1)))
            .unwrap();
    }

    archive.flush().unwrap();
}

/// Reads `N` bytes starting at absolute offset `position`, leaving the cursor just past them.
fn read_array_at<const N: usize>(archive: &mut std::fs::File, position: u64) -> [u8; N] {
    let mut bytes = [0u8; N];
    archive.seek(SeekFrom::Start(position)).unwrap();
    archive.read_exact(&mut bytes).unwrap();
    bytes
}

/// Overwrites `len` bytes at absolute offset `position` with zeros.
fn zero_at(archive: &mut std::fs::File, position: u64, len: usize) {
    archive.seek(SeekFrom::Start(position)).unwrap();
    archive.write_all(&vec![0u8; len]).unwrap();
}

/// Reads the 60-byte archive member header at the cursor. Returns its 16-byte name field and
/// the size of the member data (excluding the header and any padding byte). The cursor is
/// left at the start of the member data.
///
/// Header layout: name(16) date(12) uid(6) gid(6) mode(8) size(10) terminator "`\n"(2).
fn read_member_header(archive: &mut std::fs::File) -> ([u8; 16], u64) {
    let mut header = [0u8; 60];
    archive.read_exact(&mut header).unwrap();
    assert_eq!(&header[58..], b"`\n");

    let mut name = [0u8; 16];
    name.copy_from_slice(&header[..16]);

    // ASCII decimal, space padded; parse only this field's 10 bytes.
    let size = u64::from_str(std::str::from_utf8(&header[48..58]).unwrap().trim()).unwrap();
    (name, size)
}

/// Returns whether a member is archive bookkeeping rather than an object or import header.
///
/// This covers:
///   * the linker members ("/")
///   * the long-names table ("//")
///   * any other name starting with '/' that is not a long-name reference, such as
///     `/<ECSYMBOLS>/` (a long-name reference is '/' followed by a decimal offset)
fn is_special_member(name: &[u8; 16]) -> bool {
    name[0] == b'/' && !name[1].is_ascii_digit()
}

/// Zeroes the time stamps and version information of the archive member whose data begins
/// at absolute offset `start`. A first 16-bit field of 0 (IMAGE_FILE_MACHINE_UNKNOWN) marks a
/// short import header. Anything else is treated as a COFF object.
fn patch_member(archive: &mut std::fs::File, start: u64) {
    let sig1 = u16::from_le_bytes(read_array_at(archive, start));
    if sig1 == 0 {
        // Import header: Sig1(2) Sig2(2) Version(2) Machine(2) TimeDateStamp(4) ...
        let sig2 = u16::from_le_bytes(read_array_at(archive, start + 2));
        assert_eq!(sig2, 0xFFFF);
        zero_at(archive, start + 8, 4);
    } else {
        patch_coff_object(archive, start);
    }
}

/// Zeroes the non-reproducible parts of the COFF object whose data begins at `start`:
///   * the file header time stamp
///   * the value of a leading `@comp.id` symbol
///   * the compiler versions recorded in every `.debug$S` section
fn patch_coff_object(archive: &mut std::fs::File, start: u64) {
    // File header: Machine(2) NumberOfSections(2) TimeDateStamp(4) PointerToSymbolTable(4)
    //              NumberOfSymbols(4) SizeOfOptionalHeader(2) Characteristics(2)
    let sections = u16::from_le_bytes(read_array_at(archive, start + 2));
    zero_at(archive, start + 4, 4);

    // A zero pointer means there is no symbol table. Only the first symbol is inspected.
    let symbol_table = u32::from_le_bytes(read_array_at(archive, start + 8));
    if symbol_table != 0 {
        let symbol = start + u64::from(symbol_table);
        let name: [u8; 8] = read_array_at(archive, symbol);
        if name == *b"@comp.id" {
            // Symbol record: Name(8) Value(4) ...
            zero_at(archive, symbol + 8, 4);
        }
    }

    let optional_header = u16::from_le_bytes(read_array_at(archive, start + 16));
    let section_table = start + 20 + u64::from(optional_header);

    for index in 0..u64::from(sections) {
        // Section header (40 bytes): Name(8) VirtualSize(4) VirtualAddress(4)
        //                            SizeOfRawData(4) PointerToRawData(4) ...
        let header = section_table + index * 40;
        let name: [u8; 8] = read_array_at(archive, header);
        if name == *b".debug$S" {
            let size = u32::from_le_bytes(read_array_at(archive, header + 16));
            let data = u32::from_le_bytes(read_array_at(archive, header + 20));
            patch_debug_symbols(archive, start + u64::from(data), u64::from(size));
        }
    }
}

/// Zeroes the compiler version fields of the first `0x1013` (S_COMPILE2_ST) CodeView record in
/// the `.debug$S` section occupying `size` bytes at absolute offset `data`. The scan never
/// leaves the section, so a section without such a record is left untouched.
fn patch_debug_symbols(archive: &mut std::fs::File, data: u64, size: u64) {
    let end = data + size;

    // CodeView signature (CV_SIGNATURE_C11)
    let signature = u32::from_le_bytes(read_array_at(archive, data));
    assert_eq!(signature, 2);

    // Each record starts with Length(2) Kind(2). Length counts everything after itself.
    let mut record = data + 4;
    while record + 4 <= end {
        let length = u16::from_le_bytes(read_array_at(archive, record));
        let kind = u16::from_le_bytes(read_array_at(archive, record + 2));

        if kind == 0x1013 {
            // Flags(4) Machine(2), then frontend and backend major/minor/build (6 x u16).
            zero_at(archive, record + 4 + 4 + 2, 12);
            return;
        }

        record += 2 + u64::from(length);
    }
}

#[test]
fn test_make_reproducible() {
    //! Builds a small import library with the real `lib` tool for each supported machine and
    //! checks that `make_reproducible` normalizes the expected fields at known offsets.
    //!
    //! Requires `lib` on `PATH` (a Visual Studio developer command prompt). The offsets
    //! describe the layout `lib` produced when the test was written. `offset` is how many
    //! bytes earlier the adjusted fields sit in the x64/arm64 archives than in the x86 one.
    for (machine, offset) in [("x86", 0), ("x64", 12), ("arm64", 12)] {
        std::fs::write(
            "test.def",
            r#"
LIBRARY long-library-name-for-placement-in-long-names-table.dll
EXPORTS
A=A.#1
B=B.#2
C=C.#3
"#,
        )
        .unwrap();

        let output = std::process::Command::new("lib")
            .arg("/nologo")
            .arg("/Brepro")
            .arg(format!("/machine:{machine}"))
            .arg("/out:test.lib")
            .arg("/def:test.def")
            .output()
            .unwrap();
        assert!(
            output.status.success(),
            "lib /machine:{machine} failed:\n{}{}",
            String::from_utf8_lossy(&output.stdout),
            String::from_utf8_lossy(&output.stderr),
        );

        make_reproducible(std::path::Path::new("test.lib"));

        // Only reads are needed once the archive has been patched.
        let mut archive = std::fs::File::open("test.lib").unwrap();
        let mut read_at = |position: u64, len: usize| {
            let mut buf = vec![0u8; len];
            archive.seek(SeekFrom::Start(position)).unwrap();
            archive.read_exact(&mut buf).unwrap();
            buf
        };

        // Archive member header timestamp
        assert_eq!(read_at(0x18, 12), b"-1          ");

        // COFF file header time/date stamp
        assert_eq!(read_at(0x322 - offset, 4), [0u8; 4]);

        // Compiler version symbol
        assert_eq!(read_at(0x481 - offset, 12), b"@comp.id\0\0\0\0");

        // Import header time/date stamp
        assert_eq!(read_at(0x493 - offset, 4), [0u8; 4]);

        // Sampling to ensure lib is consistent
        assert_eq!(read_at(0xB4, 24), b"__NULL_IMPORT_DESCRIPTOR");
        assert_eq!(read_at(0x26E - offset, 4), b"//  ");
        assert_eq!(read_at(0x405 - offset, 18), b"Microsoft (R) LINK");

        drop(archive);

        for file in ["test.lib", "test.exp", "test.def"] {
            std::fs::remove_file(file).unwrap();
        }
    }
}
